def get_deepspeed_config(args):
    """Build a DeepSpeed configuration dictionary from parsed arguments.

    Most numeric knobs are set to the string ``"auto"`` so that whatever
    consumes the config can fill them in (presumably the Hugging Face
    Trainer integration -- confirm against the caller).

    Args:
        args: Object exposing the boolean-like attributes ``use_zero2``,
            ``use_zero3``, ``offload_params``, ``fp16`` and ``bf16``.

    Returns:
        dict: The configuration. It always contains the optimizer and
        batch-size entries. ``"zero_optimization"`` is present only when
        ``use_zero2`` or ``use_zero3`` is set (``use_zero2`` wins if both
        are set). ``"fp16"`` / ``"bf16"`` are present only when the
        matching flag is set.

    Notes:
        * If ``offload_params`` is set without a ZeRO stage, the flag is
          ignored and a ``UserWarning`` is emitted, because there is no
          ``"zero_optimization"`` section to attach the offload to.
        * ``fp16`` and ``bf16`` are not checked for mutual exclusion here;
          if both are set, both sections are emitted.
    """
    ds_config = {
        "steps_per_print": 1000,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": "auto",
                "betas": "auto",
                "eps": "auto",
                "weight_decay": "auto",
            },
        },
    }

    # The scheduler section is currently disabled; kept for reference.
    # ds_config["scheduler"] = {
    #     "type": "WarmupLR",
    #     "params": {
    #         "warmup_min_lr": "auto",
    #         "warmup_max_lr": "auto",
    #         "warmup_num_steps": "auto",
    #     },
    # }

    ds_config["gradient_accumulation_steps"] = "auto"
    ds_config["gradient_clipping"] = "auto"
    ds_config["train_batch_size"] = "auto"
    ds_config["train_micro_batch_size_per_gpu"] = "auto"

    # ZeRO stage selection: stage 2 takes precedence when both flags are set.
    if args.use_zero2:
        ds_config["zero_optimization"] = {
            "stage": 2,
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
        }
    elif args.use_zero3:
        ds_config["zero_optimization"] = {
            "stage": 3,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "offload_param": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e9,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_gather_16bit_weights_on_model_save": True,
        }

    if args.offload_params:
        zero_section = ds_config.get("zero_optimization")
        if zero_section is None:
            # Without a ZeRO section there is nothing to attach the offload
            # settings to. Indexing the missing key used to raise KeyError
            # here; warn and skip instead.
            import warnings

            warnings.warn(
                "offload_params was requested but neither use_zero2 nor "
                "use_zero3 is enabled; parameter offload is ignored.",
                UserWarning,
                stacklevel=2,
            )
        else:
            # NOTE(review): DeepSpeed documents offload_param as a stage-3
            # feature; under stage 2 this entry may have no effect -- confirm.
            zero_section["offload_param"] = {
                "device": "cpu",
                "pin_memory": True,
            }

    if args.fp16:
        ds_config["fp16"] = {
            "enabled": True,
            # 0.0 presumably selects dynamic loss scaling -- confirm in the
            # DeepSpeed docs.
            "loss_scale": 0.0,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 32,
        }

    if args.bf16:
        ds_config["bf16"] = {"enabled": True}

    return ds_config
